
Implement an algsimp optimization for dot operation. #28170

Merged
merged 2 commits into tensorflow:master on Apr 29, 2019

Conversation

@BinFan (Contributor) commented Apr 26, 2019

The basic idea is that dot(reshape(transpose(A)), constant) can be replaced by dot(reshape(A), reshape(transpose(reshape(constant)))) if the effect of the lhs transpose and reshape is to reorder elements in the lhs contracting dims. We apply the inverse reordering on the constant side, where it can then be constant folded.
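
As a sanity check of the identity the rewrite relies on, here is a standalone editor's sketch (not code from the PR): permuting the elements along the lhs contracting dimension gives the same dot result as applying the inverse permutation along the rhs contracting dimension, which is why the reordering can be pushed onto the constant operand and folded away.

#include <array>
#include <cassert>

// Toy 1-D version of the identity behind the rewrite: for a permutation P,
// dot(P(lhs), rhs) == dot(lhs, inverse(P)(rhs)). The values are arbitrary.
int main() {
  const std::array<int, 3> lhs = {1, 2, 3};
  const std::array<int, 3> rhs = {4, 5, 6};
  const std::array<int, 3> perm = {2, 0, 1};  // P(v)[i] = v[perm[i]]

  std::array<int, 3> inv{};  // inverse permutation: inv[perm[i]] = i
  for (int i = 0; i < 3; ++i) inv[perm[i]] = i;

  int dot_lhs_reordered = 0;  // dot(P(lhs), rhs)
  int dot_rhs_reordered = 0;  // dot(lhs, inverse(P)(rhs))
  for (int i = 0; i < 3; ++i) {
    dot_lhs_reordered += lhs[perm[i]] * rhs[i];
    dot_rhs_reordered += lhs[i] * rhs[inv[i]];
  }
  assert(dot_lhs_reordered == dot_rhs_reordered);  // both are 29
  return 0;
}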

@tensorflow-bot added the size:L (CL Change Size: Large) label Apr 26, 2019
@BinFan (Contributor, Author) commented Apr 26, 2019

@jlebar This was previously PR 27160, the dot + transpose algsimp optimization. Somehow I screwed up the git log of that PR and could not find a way to fix it within the PR, so I created this new one. Sorry about the duplication.

@jlebar (Contributor) commented Apr 26, 2019

Huh, I am surprised a simple git push -f to your remote branch didn't fix the problem. But oh well. Would you be willing to give me a list of which comments, if any, are outstanding from the previous PR?

One is the question about the protobuf int64 class. We have an AsInt64Slice helper for exactly this.

@BinFan (Contributor, Author) commented Apr 26, 2019

Huh, I am surprised a simple git push -f to your remote branch didn't fix the problem. But oh well. Would you be willing to give me a list of which comments, if any, are outstanding from the previous PR?

One is the question about the protobuf int64 class. We have an AsInt64Slice helper for exactly this.

About the RepeatedField and int64 type comment: in this push I copied lhs_contracting_dims and rhs_contracting_dims out to std::vectors at the beginning and manipulate the vectors from then on, since we do not actually modify the dnums of the dot anyway.

Besides a few small typos, here are the outstanding comments I can remember:

  1. In checking that the lhs reshape squishes some dims into one, change the for loop to std::find_if. Your original suggestion was absl::c_linear_search, but we need a custom comparison function here. Also add some comments about the check.
  2. In checking the lhs transpose, remove the unnecessary condition.
  3. In updating the lhs contracting dims after "pulling in" the lhs transpose, use the ComposePermutation helper. I compute the permutation vector within the contracting dims to work around the check in the current ComposePermutation implementation (see the generic sketch after this list).
  4. Simplify the comments before transforming the rhs.
  5. In inverting the reshape and transpose, simplify the implementation.
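
For reference, here is a generic editor's sketch of the permutation composition mentioned in item 3; the actual ComposePermutation helper in XLA may use a different signature or convention, so treat this only as an illustration of the idea.

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Composes two permutations of equal length: result[i] = p1[p2[i]]. Whether
// this reads as "p1 after p2" or "p2 after p1" depends on whether permutations
// are interpreted as gathers or scatters, so the convention must match the
// caller's.
std::vector<int64_t> ComposePermutationsSketch(absl::Span<const int64_t> p1,
                                               absl::Span<const int64_t> p2) {
  std::vector<int64_t> result(p2.size());
  for (size_t i = 0; i < p2.size(); ++i) {
    result[i] = p1[p2[i]];
  }
  return result;
}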

@gbaned self-assigned this Apr 26, 2019
@gbaned added the comp:xla (XLA) label Apr 26, 2019
@gbaned requested a review from jlebar Apr 26, 2019 06:36
@jlebar previously approved these changes Apr 26, 2019

@jlebar (Contributor) left a comment

Just a few minor things, otherwise this looks great!

reshape->operand(0)->shape(), reshape->shape());
CHECK_EQ(lhs_contracting_dims.size(), 1);
if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
(std::find_if(unmodified_dims.begin(), unmodified_dims.end(),

absl::c_find_if, but better would be absl::c_any_of.
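
For illustration (an editor's sketch with stand-in types, not code from the PR), the difference is mostly ergonomic: absl::c_find_if returns an iterator that still has to be compared against end(), while absl::c_any_of returns the bool the condition actually needs.

#include <cstdint>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"

using DimPair = std::pair<int64_t, int64_t>;

// Returns true if any unmodified output dim equals the contracting dim.
bool TouchesContractingDim(const std::vector<DimPair>& unmodified_dims,
                           int64_t contracting_dim) {
  // Equivalent but noisier:
  //   absl::c_find_if(unmodified_dims, pred) != unmodified_dims.end()
  return absl::c_any_of(unmodified_dims, [&](const DimPair& p) {
    return p.second == contracting_dim;
  });
}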


// Check if reshape squishes some dims into one dim, and that this one
// dim is the dot's lhs contracting dim.
// The size of unmodified_dims should be N - 1, where N is the rank of the

Nit: if this line starts a new paragraph, put a blank line before it. If it is not a new paragraph, flow it up with the previous line.

return nullptr;
}

// Check if reshape squishes some dims into one dim, and that this one

Nit, "Check if...and that" has bad parallelism, best fix is probably "Check that...and that".


// Require single contracting dim to make the implementation easier to
// track contracting dims.
if (dnums.lhs_contracting_dimensions_size() != 1) {

If we pull the vectors lhs_contracting_dims and rhs_contracting_dims below this if statement, then we can simply do

// Comment explaining why we're pulling these into vectors. I am still not
// sure what problem this solves; it seems more complex to have two copies of
// one piece of data?
std::vector<int64> lhs_contracting_dims = {dnums.lhs_contracting_dimensions(0)};

@tensorflow-bot added the kokoro:force-run (Tests on submitted change) and ready to pull (PR ready for merge process) labels Apr 26, 2019
@kokoro-team removed the kokoro:force-run (Tests on submitted change) label Apr 26, 2019
The basic idea is that dot(reshape(transpose(A)), constant) can be replaced by dot(reshape(A), reshape(transpose(reshape(constant)))) if the effect of the lhs transpose and reshape is to reorder elements in the contracting dims. We invert the reordering on the constant side so that it can be constant folded.
@BinFan (Contributor, Author) commented Apr 26, 2019

Thanks for approving the PR. While you were reviewing it I made some changes so that lhs_contracting_dims is not copied out beforehand. If this looks better, or equally good, I can leave it this way. I will fix the things you suggested as well.

@jlebar (Contributor) commented Apr 26, 2019

I made some changes so that lhs_contracting_dims is not copied out beforehand. If this looks better, or equally good, I can leave it this way

Looks even better to me!

@BinFan (Contributor, Author) commented Apr 26, 2019

@jlebar My previous push seems to have dismissed your approval. I just updated the PR to address the review comments. Basically this revision just changes absl::c_find_if to absl::c_any_of and fixes comment formatting. If this revision looks good to you, could you approve it again? Thanks!

@jlebar (Contributor) left a comment

\o/ @gbaned would you be willing to merge this?

@tensorflow-bot added the kokoro:force-run (Tests on submitted change) label Apr 26, 2019
@kokoro-team removed the kokoro:force-run (Tests on submitted change) label Apr 26, 2019
@gbaned added this to Assigned Reviewer in PR Queue via automation Apr 29, 2019
@gbaned (Contributor) commented Apr 29, 2019

@jlebar Sure. I'm taking care of this PR and helping to get it merged. Thank you!

@gbaned added the kokoro:force-run (Tests on submitted change) and ready to pull (PR ready for merge process) labels and removed the ready to pull (PR ready for merge process) label Apr 29, 2019
@kokoro-team removed the kokoro:force-run (Tests on submitted change) label Apr 29, 2019
@tensorflow-copybara merged commit e7f555d into tensorflow:master Apr 29, 2019
PR Queue automation moved this from Assigned Reviewer to Merged Apr 29, 2019
tensorflow-copybara pushed a commit that referenced this pull request Apr 29, 2019
PiperOrigin-RevId: 245751800
pull bot pushed a commit to Cache-Cloud/tensorflow that referenced this pull request Apr 29, 2019
@jlebar (Contributor) commented Apr 29, 2019

I had to roll this back due to a test failure; one of the CHECKs added here was failing.

Overall this is kind of a good thing: it means that a real model used in production is affected by this change. :) I will see if there's an easy fix I can make, and if not I'll give you a testcase.

@BinFan (Contributor, Author) commented Apr 29, 2019

Thanks! Let me know if there is anything I can do on my side.

@jlebar (Contributor) commented Apr 30, 2019

@BinFan would you be willing to check the following patch for me? The first testcase in here is the one that was crashing for me.

diff --git a/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier.cc
--- a/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1558,28 +1558,37 @@ AlgebraicSimplifierVisitor::OptimizeDotO
       reshape->operand(0)->shape(), reshape->shape());
   CHECK_EQ(lhs_contracting_dims.size(), 1);
   if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
       absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
         return p.second == lhs_contracting_dims[0];
       })) {
     return nullptr;
   }
-  // Virtually pull the reshape into the dot. Now the dot is equivalent to a
-  // new dot with "unsquished" lhs contracting dims. We don't need to actually
-  // create a new dot instruction. We can just keep track of lhs and
-  // lhs_contracting_dims.
-  CHECK_GT(reshape->operand(0)->shape().rank(), reshape->shape().rank());
-  lhs_contracting_dims.Resize(
-      reshape->operand(0)->shape().rank() - reshape->shape().rank() + 1,
-      lhs_contracting_dims[0]);
-  absl::c_iota(lhs_contracting_dims, lhs_contracting_dims[0]);
+
+  // Virtually pull the reshape into the dot so the dot operates on the
+  // transpose, with "unsquished" lhs contracting dims.  The new contracting
+  // dims are all of the dims that are modified by the reshape -- that is, every
+  // dimension that's not in `unmodified_dims[i].first`.
+  //
+  // (We don't need to actually create a new dot instruction. We can just keep
+  // track of lhs and lhs_contracting_dims.)
+  absl::flat_hash_set<int64> unmodified_transpose_dims;
+  for (const auto& pair : unmodified_dims) {
+    unmodified_transpose_dims.insert(pair.first);
+  }
+  lhs_contracting_dims.Clear();
+  for (int64 i = 0; i < transpose->shape().dimensions_size(); ++i) {
+    if (!unmodified_transpose_dims.contains(i)) {
+      lhs_contracting_dims.Add(i);
+    }
+  }
   lhs = lhs->mutable_operand(0);
 
-  // Check if transpose only permutes the contracting dims.
+  // Check that the transpose only permutes the contracting dims.
   const auto& transpose_dims = transpose->dimensions();
   for (int64 i = 0; i < transpose_dims.size(); ++i) {
     if (transpose_dims[i] != i &&
         !absl::c_linear_search(lhs_contracting_dims, i)) {
       return nullptr;
     }
   }
   // Virtually pull the transpose into the dot. Now the dot is equivalent to
diff --git a/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
--- a/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/google3/third_party/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -5297,10 +5297,57 @@ TEST_F(AlgebraicSimplifierTest, DotContr
       [](const Shape&, const Shape&) { return false; });
   options.set_is_layout_sensitive(true);
   // The transformation of moving transpose and reshape to the constant side is
   // layout insensitive. It should not happen if AlgebraicSimplifier is set up
   // to be layout sensitive.
   ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
 }
 
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) {
+  // This isn't transformed (notice that the relative order of the `2` and `3`
+  // dims doesn't change, so there's no opportunity here), but it's nonetheless
+  // an interesting testcase because of the presence of the size-1 dimensions.
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+     param = f32[1,2,5,3] parameter(0)
+     transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3}
+     reshape = f32[5,6] reshape(transpose)
+     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
+     ROOT dot = f32[5,4] dot(reshape, constant),
+       lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+     param = f32[1,2,3,5] parameter(0)
+     transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3}
+     reshape = f32[6,5] reshape(transpose)
+     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
+     ROOT dot = f32[5,4] dot(reshape, constant),
+       lhs_contracting_dims={0}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {6, 5});
+  auto shape2 = ShapeUtil::MakeShape(F32, {1, 3, 2, 4});
+  auto shape3 = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
+  const HloInstruction* transpose;
+  ASSERT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot(
+                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
+                  m::Reshape(m::Transpose(&transpose,
+                                          m::Reshape(m::Constant())
+                                              .WithShapeCompatibleTo(&shape2))
+                                 .WithShapeCompatibleTo(&shape3)))));
+  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
+}
+
 }  // namespace
 }  // namespace xla
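
The core of the fix above is the complement computation: the new lhs contracting dims are exactly the transpose dims that the reshape modifies. Here is a standalone editor's sketch of that step using plain containers (names simplified; this is not the XLA code itself):

#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// Given the (input dim, output dim) pairs that the reshape leaves unmodified
// and the rank of the transpose's shape, the new contracting dims are every
// transpose dim that the reshape does modify.
std::vector<int64_t> ContractingDimsAfterPullingInReshape(
    const std::vector<std::pair<int64_t, int64_t>>& unmodified_dims,
    int64_t transpose_rank) {
  std::set<int64_t> unmodified_transpose_dims;
  for (const auto& pair : unmodified_dims) {
    unmodified_transpose_dims.insert(pair.first);
  }
  std::vector<int64_t> contracting_dims;
  for (int64_t i = 0; i < transpose_rank; ++i) {
    if (unmodified_transpose_dims.count(i) == 0) {
      contracting_dims.push_back(i);
    }
  }
  return contracting_dims;
}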

@BinFan (Contributor, Author) commented Apr 30, 2019

@jlebar Thanks a lot for the patch! It looks good, and indeed I missed the size-1 dim case.

I'm wondering if we should add a check after filling in unmodified_transpose_dims, something like

if (!is_iota(unmodified_transpose_dims)) { // I did not find an std or absl library for is_iota
  return nullptr;
}

because I'm thinking of this example:

  param = f32[2,5,1,3] parameter(0)
  transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
  reshape = f32[5,6] reshape(transpose)
  constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
  ROOT dot = f32[5,4] dot(reshape, constant),
    lhs_contracting_dims={1}, rhs_contracting_dims={0}

I think this example would pass all the checks: after pulling in the reshape, lhs_contracting_dims={0,2,3}, and the transpose only permutes dimensions 0 and 2. But the relative order of dims 2 and 3 does not change either, so there should be no opportunity here.
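
For illustration only, here is a minimal editor's sketch of the kind of is_iota check being described; the helper name and its use here are hypothetical, not something that was added in the PR.

#include <cstdint>
#include "absl/types/span.h"

// Hypothetical helper: true iff dims is exactly {0, 1, ..., dims.size() - 1}.
bool IsIota(absl::Span<const int64_t> dims) {
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}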

@jlebar (Contributor) commented Apr 30, 2019

I think this example would pass all the checks

This one does not trigger the transformation. I didn't step through it in a debugger, but I think it's because the size-1 dim is considered a contracting dim.

I added this one as a testcase.

@BinFan (Contributor, Author) commented Apr 30, 2019

Looks good to me. Thanks!

tensorflow-copybara pushed a commit that referenced this pull request Apr 30, 2019
tensorflow-copybara pushed a commit that referenced this pull request May 1, 2019
No functional change.

Relevant to PR #28170.

PiperOrigin-RevId: 246051968